"""========================================================================
This file inherits the encoder, decoder, and the channel modules used for joint source-channel coding of text data. Also includes the system and the train file.
It expands the channel coding aspect of the network by passing the outputs of the recurrent nets in the JSC architecture through a multi-layered residual network.


@author: Nariman Farsad, Milind Rao (original authors) and Avoy Datta (channel coding aspect)
@copyright: Copyright 2018
========================================================================"""
import os
import tensorflow as tf
print(tf.__version__)
import numpy as np
import pickle
import time
#from tensorflow.python.layers import core as layers_core
#from tensorflow.contrib.seq2seq import ScheduledEmbeddingTrainingHelper
#from tensorflow.contrib.seq2seq import AttentionWrapper, AttentionWrapperState
#from tensorflow.python.ops import random_ops
#from tensorflow.python.ops import array_ops
#from tensorflow.python.ops import math_ops
#from tensorflow.python.ops import control_flow_ops
#from tensorflow.python.ops import check_ops
#from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from functools import partial
from preprocess_library import SentenceBatchGenerator, Word2Numb, bin_batch_create
import bisect
from tqdm import tqdm
import argparse
from threading import Thread


# Index of the GPU to use; it is exported (as a string) through the
# CUDA_VISIBLE_DEVICES environment variable when this module is imported.
gpu_num = 1
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_num)

class Config(object):
    """The model configuration.

    Collects the hyper-parameters, file paths and logging/checkpoint
    intervals in one object. Almost every constructor argument is stored
    unchanged as an attribute of the same name; the exceptions are listed
    in the ``__init__`` docstring.
    """

    # Ids of the special tokens (they are counted in ``vocab_size``).
    PAD = 0
    EOS = 1
    SOS = 2
    UNK = 3

    def __init__(self,
                 src_model_path = None,
                 chan_model_path = None,
                 full_model_path = None,
                 channel={'type':'erasure','chan_param':0.90},
                 numb_epochs=10,
                 lr=0.001,
                 numb_tx_bits=400, # Length of output of old source encoder (after scale down)

                 chan_code_rate = 0.90,
                 enc_hidden_act = tf.nn.relu,
                 dec_hidden_act = tf.nn.relu,

                 load_mech = None, #Mechanism with which to load data. Set from input optional arguments at runtime

                 vocab_size=19158, # including special <PAD> <EOS> <SOS> <UNK>
                 embedding_size=200,
                 enc_hidden_units=256,
                 numb_enc_layers=2,
                 numb_dec_layers=2,
                 batch_size=512,
                 batch_size_test = 128,
                 length_from=4,
                 length_to=30,
                 bin_len = 4,
                 bits_per_bin = [200,250,300,350,400,450,500],
                 variable_encoding = None,
                 deep_encoding = False, #Never works
                 deep_encoding_params = [1000,800,600],
                 peephole = True,
                 dataset = 'euro',
                 w2n_path="../data/w2n_n2w_TopEuro.pickle",
#                 traindata_path="../../data/training_euro_wordlist20.pickle",
                 traindata_path="../data/corpora/europarl-v7.en/europarl-v7.en",
#                 testdata_path="../../data/testing_euro_wordlist20.pickle",
                 testdata_path="../data/corpora/europarl-v7.en/europarl-v7.en",
                 embed_path="../data/200_embed_large_TopEuro.pickle",
                 print_every = 1,
                 max_test_counter = int(1e6),
                 max_validate_counter = 10000,
                 max_batch_in_epoch = int(1e9),
                 save_every = int(1e5),
                 summary_every = 20,
                 qcap = 200,

                 chan_enc_layers = [4096, 2048, 1024, 512],
                 chan_dec_layers = [4096, 2048, 1600, 1024, 512],

                 **kwargs):
        """
        Args:
            src_model_path - path where the source encoder model is saved (trained on its own)
            chan_model_path - path where the channel coder network (pre-trained) is saved
            full_model_path - path where the entire network is saved (both parts run together);
                also used as the initial value of ``model_save_path``
            channel - dict with keys type [erasure,awgn] and chan_param
            numb_epochs - stored as ``self.epochs``
            numb_tx_bits - length of the source encoder output; also the fallback value
                for every bin when ``bits_per_bin`` is empty/None
            chan_code_rate - ``num_chan_bits`` is ceil(numb_tx_bits / chan_code_rate)
            bits_per_bin - ensure this is even; presumably one entry per length bin
                (``range(length_from, length_to, bin_len)``), which is what the
                defaults provide. If falsy, ``numb_tx_bits`` is used for every bin.
            variable_encoding - accepted for compatibility but ignored;
                ``self.variable_encoding`` is always None
            enc_hidden_units - ``dec_hidden_units`` is set to twice this value
            chan_enc_layers, chan_dec_layers - either a list of layer widths
                (stored in ``enc_layers_dims`` / ``dec_layers_dims``, with the
                list length in ``num_chan_enc_layers`` / ``num_chan_dec_layers``)
                or a non-list value that is stored directly as the layer count
                (the ``*_layers_dims`` attribute is then None)
            **kwargs - any extra keyword arguments, kept as the dict ``self.kwargs``

        All other arguments are stored as attributes of the same name.
        List/dict arguments are shallow-copied so that the instance never
        aliases the shared default objects of this signature.
        """
        # --- channel coder layer layout ---------------------------------
        # The layer count is always published as num_chan_{enc,dec}_layers.
        # The non-list case historically used num_{enc,dec}_layers instead;
        # that name is kept as an alias so existing readers keep working.
        if isinstance(chan_enc_layers, list):
            self.enc_layers_dims = self._copy_if_mutable(chan_enc_layers)
            self.num_chan_enc_layers = len(chan_enc_layers)
        else:
            self.num_chan_enc_layers = chan_enc_layers
            self.num_enc_layers = chan_enc_layers  # legacy alias
            self.enc_layers_dims = None

        if isinstance(chan_dec_layers, list):
            self.dec_layers_dims = self._copy_if_mutable(chan_dec_layers)
            self.num_chan_dec_layers = len(chan_dec_layers)
        else:
            self.num_chan_dec_layers = chan_dec_layers
            self.num_dec_layers = chan_dec_layers  # legacy alias
            self.dec_layers_dims = None

        self.mod_chan_coding = True #Set in main func, but true by default

        self.load_mech = load_mech

        # --- activations and model paths --------------------------------
        self.enc_hidden_act = enc_hidden_act
        self.dec_hidden_act = dec_hidden_act
        self.src_model_path = src_model_path
        self.chan_model_path = chan_model_path
        self.full_model_path = full_model_path
        # Default path where trained models are saved. Also set manually at
        # main function (switched to src_model_path if src coder trained alone).
        self.model_save_path = full_model_path

        # --- optimisation and network sizes -----------------------------
        self.epochs = numb_epochs
        self.lr = lr
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size # length of embeddings
        self.enc_hidden_units = enc_hidden_units
        self.numb_enc_layers = numb_enc_layers
        self.numb_dec_layers = numb_dec_layers
        self.dec_hidden_units = enc_hidden_units * 2

        # --- batch properties -------------------------------------------
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.length_from = length_from
        self.length_to = length_to
        self.bin_len = bin_len
        if not bits_per_bin:
            # No per-bin budget given: use numb_tx_bits for every length bin.
            self.bits_per_bin = [numb_tx_bits for _ in range(length_from,length_to,bin_len)]
        else:
            self.bits_per_bin = self._copy_if_mutable(bits_per_bin)
        # The variable_encoding argument is deliberately not used here.
        self.variable_encoding = None #variable_encoding
        self.deep_encoding = deep_encoding
        self.deep_encoding_params = self._copy_if_mutable(deep_encoding_params)

        # --- channel ----------------------------------------------------
        self.peephole = peephole
        self.channel = self._copy_if_mutable(channel)
        self.numb_tx_bits = numb_tx_bits
        self.chan_code_rate = chan_code_rate

        # NOTE(review): np.ceil returns a NumPy float (e.g. 445.0), not an
        # int; left as-is because other code may rely on that -- confirm
        # before casting.
        self.num_chan_bits = np.ceil(numb_tx_bits / chan_code_rate)

        # --- data locations ---------------------------------------------
        self.w2n_path = w2n_path
        self.traindata_path = traindata_path
        self.testdata_path = testdata_path
        self.embed_path = embed_path
        self.dataset=dataset

        # --- bookkeeping ------------------------------------------------
        # Lower bounds of the length bins: length_from, length_from+bin_len, ...
        self.queue_limits = list(range(length_from,length_to,bin_len))
        self.print_every = print_every
        self.max_test_counter = max_test_counter
        self.max_batch_in_epoch = max_batch_in_epoch
        self.save_every = save_every
        self.summary_every = summary_every
        self.max_validate_counter = max_validate_counter

        self.qcap = qcap
        self.kwargs=kwargs

    @staticmethod
    def _copy_if_mutable(value):
        """Return a shallow copy of a list or dict; pass anything else through.

        Used so an instance never stores (and later mutates) one of the
        shared default objects from the ``__init__`` signature.
        """
        if isinstance(value, list):
            return value[:]
        if isinstance(value, dict):
            return value.copy()
        return value
     
def generate_tb_filename(config):
    """Generate the file name for the model
    """
    tb_name = ""
    if config.channel["type"] == "none":
        tb_name = "Ch-" + config.channel["type"]
    else:
        tb_name = "Ch-" + config.channel["type"] + '{:0.2f}'.format(config.channel["chan_param"])
